{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"A very simple MNIST classifier.\n",
    "See extensive documentation at\n",
    "https://www.tensorflow.org/get_started/mnist/beginners\n",
    "\"\"\"\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import sys\n",
    "\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "FLAGS = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_NODE = 784  #输入层的节点数，图片为28*28，为图片的像素\n",
    "OUTPUT_NODE = 10   #输出层的节点数，等于类别的数目，需要区分0-9，所以为10类\n",
    "TRAINING_STEPS = 30000 #训练轮数\n",
    "BATCH_SIZE = 50 #一个训练batch中的训练数据个数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们在这里调用系统提供的Mnist数据函数为我们读入数据，如果没有下载的话则进行下载。\n",
    "\n",
    "<font color=#ff0000>**这里将data_dir改为适合你的运行环境的目录**</font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-3-cf0ad10d35ed>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From D:\\python\\python3.7\\envs\\tensorflow_gpu\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From D:\\python\\python3.7\\envs\\tensorflow_gpu\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\python\\python3.7\\envs\\tensorflow_gpu\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\python\\python3.7\\envs\\tensorflow_gpu\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From D:\\python\\python3.7\\envs\\tensorflow_gpu\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "# Import data\n",
    "data_dir = 'MNIST_data/'\n",
    "mnist = input_data.read_data_sets(data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#x为训练数据的占位符，y_为训练标签的占位符\n",
    "x = tf.placeholder(tf.float32, [None, INPUT_NODE])\n",
    "y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#将单张图片从784维还原为28X28的矩阵\n",
    "x_image = tf.reshape(x,[-1,28,28,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weight_variable(shape):\n",
    "    initial = tf.truncated_normal(shape,stddev=0.1)\n",
    "    return tf.Variable(initial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bias_variable(shape):\n",
    "    initial = tf.constant(0.1,shape=shape)\n",
    "    return tf.Variable(initial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d(x,w):\n",
    "    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_pool_2x2(x):\n",
    "    return tf.nn.max_pool(x,ksize = [1,2,2,1],strides = [1,2,2,1],padding = 'SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#第一层卷积层\n",
    "W_conv1 = weight_variable([5,5,1,32])\n",
    "b_conv1 = bias_variable([32])\n",
    "h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1) + b_conv1)\n",
    "h_pool1 = max_pool_2x2(h_conv1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#第二层卷积层\n",
    "W_conv2 = weight_variable([5,5,32,64])\n",
    "b_conv2 = bias_variable([64])\n",
    "h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)\n",
    "h_pool2 = max_pool_2x2(h_conv2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "#全连接层，输出为1024\n",
    "W_fc1 = weight_variable([7*7*64,1024])\n",
    "b_fc1  = bias_variable([1024])\n",
    "h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])\n",
    "h = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1 ) + b_fc1 )\n",
    "\n",
    "#设定Dropout为0.5\n",
    "keep_prob = tf.placeholder(tf.float32)\n",
    "h_drop = tf.nn.dropout(h,keep_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "#把1024维转换成10维\n",
    "W_fc2 = weight_variable([1024,OUTPUT_NODE])\n",
    "b_fc2  = bias_variable([OUTPUT_NODE])\n",
    "y = tf.matmul(h_drop,W_fc2) + b_fc2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们计算交叉熵，注意这里不要使用注释中的手动计算方式，而是使用系统函数。\n",
    "另一个注意点就是，softmax_cross_entropy_with_logits的logits参数是**未经激活的wx+b**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-14-bf86c3447efc>:11: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# The raw formulation of cross-entropy,\n",
    "#\n",
    "#   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),\n",
    "#                                 reduction_indices=[1]))\n",
    "#\n",
    "# can be numerically unstable.\n",
    "#\n",
    "# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw\n",
    "# outputs of 'y', and then average across the batch.\n",
    "cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成一个训练step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_step = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entropy)\n",
    "\n",
    "# Test trained model\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0,trainning accuracy 0.02\n",
      "step 1000,trainning accuracy 0.94\n",
      "step 2000,trainning accuracy 0.96\n",
      "step 3000,trainning accuracy 0.98\n",
      "step 4000,trainning accuracy 0.94\n",
      "step 5000,trainning accuracy 0.9\n",
      "step 6000,trainning accuracy 0.92\n",
      "step 7000,trainning accuracy 0.98\n",
      "step 8000,trainning accuracy 0.96\n",
      "step 9000,trainning accuracy 1\n",
      "step 10000,trainning accuracy 0.96\n",
      "step 11000,trainning accuracy 0.94\n",
      "step 12000,trainning accuracy 0.98\n",
      "step 13000,trainning accuracy 1\n",
      "step 14000,trainning accuracy 0.98\n",
      "step 15000,trainning accuracy 0.96\n",
      "step 16000,trainning accuracy 0.92\n",
      "step 17000,trainning accuracy 0.94\n",
      "step 18000,trainning accuracy 0.98\n",
      "step 19000,trainning accuracy 1\n",
      "step 20000,trainning accuracy 0.96\n",
      "step 21000,trainning accuracy 0.98\n",
      "step 22000,trainning accuracy 1\n",
      "step 23000,trainning accuracy 0.92\n",
      "step 24000,trainning accuracy 0.98\n",
      "step 25000,trainning accuracy 0.98\n",
      "step 26000,trainning accuracy 0.96\n",
      "step 27000,trainning accuracy 0.96\n",
      "step 28000,trainning accuracy 1\n",
      "step 29000,trainning accuracy 0.96\n",
      "test accuracy 0.9785\n"
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    tf.global_variables_initializer().run() #参数初始化\n",
    "\n",
    "    # Train\n",
    "    for i in range(TRAINING_STEPS):\n",
    "        batch = mnist.train.next_batch(BATCH_SIZE)\n",
    "        if i%1000==0 :        \n",
    "            train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})        \n",
    "            print(\"step %d,trainning accuracy %g\" % (i,train_accuracy))\n",
    "        train_step.run(feed_dict = {x:batch[0],y_:batch[1],keep_prob:0.5})\n",
    "    print(\"test accuracy %g\" % accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TRAINING_STEPS训练轮数调整：3000发现巡检的准确率太低，调整到20000仍没有达到理解效果且不稳定，最后改到30000\n",
    "### BATCH_SIZ个训练batch数：从100调整到50\n",
    "### 学习率：从0.1调整到0.001\n",
    "### 第一层卷积层：参考网站资料，选用[5,5,1,32]的卷积（length*width*n*channel），卷积后选 用Relu作为激活函数\n",
    "### 第二层卷积层：选用[5,5,32,64]的卷积，卷积后选用Relu作为激活函数\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "验证我们模型在测试数据上的准确率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毫无疑问，这个模型是一个非常简陋，性能也不理想的模型。目前只能达到92%左右的准确率。\n",
    "接下来，希望大家利用现有的知识，将这个模型优化至98%以上的准确率。\n",
    "Hint：\n",
    "- 卷积\n",
    "- 池化\n",
    "- 激活函数\n",
    "- 正则化\n",
    "- 初始化\n",
    "- 摸索一下各个超参数\n",
    "  - 卷积kernel size\n",
    "  - 卷积kernel 数量\n",
    "  - 学习率\n",
    "  - 正则化惩罚因子\n",
    "  - 最好每隔几个step就对loss、accuracy等等进行一次输出，这样才能有根据地进行调整"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python tensorflow_gpu",
   "language": "python",
   "name": "tensorflow_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
